#include<bits/stdc++.h>
#define int long long
using namespace std;
// Fast reader: parses one decimal integer from stdin using getchar().
// Non-digit characters before the number are skipped. Each '-' met while
// skipping toggles the sign, so "-5" -> -5, "- 5" -> -5, "--5" -> 5.
// Reading stops at the first non-digit after the digits; that character is
// consumed. If end of input is reached before any digit, returns 0 instead
// of looping forever.
// No overflow check: the value must fit in the return type.
inline int read(){
	int x=0;
	bool neg=false;
	int ch=getchar();  // int, not char: must be able to represent EOF
	while(ch!=EOF&&(ch<'0'||ch>'9')){
		if(ch=='-')neg=!neg;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		x=x*10+(ch-'0');
		ch=getchar();
	}
	return neg?-x:x;
}
// NOTE(review): Maxn is not referenced in the visible code — possibly a
// leftover array bound; confirm before removing.
constexpr int Maxn=1e7+3;
// Prime modulus used by ksm, mul and main.
constexpr int mod=998244353;
// NOTE(review): st is unused in the visible code; together with `en` it may
// bracket the global variables (e.g. for measuring static memory) — confirm
// before removing or reordering.
bool st;
// Problem inputs n, K and the running factor updated inside main.
int n,K,fac;
// Modular exponentiation by squaring: returns a^k mod `mod` (998244353, prime).
// - a may be any value; it is first reduced into [0, mod).
// - a ≡ 0 (mod mod): returns 1 for k == 0 and 0 otherwise. A negative k has
//   no inverse to raise here, so it also yields 0.
// - Otherwise the exponent is reduced modulo mod-1 (Fermat's little theorem,
//   valid because gcd(a, mod) = 1 here). A negative k therefore means a power
//   of a's modular inverse, instead of looping forever on k >>= 1.
// Relies on the file-level `#define int long long`: a, res < mod < 2^30, so
// a*a and res*a fit without overflow.
inline int ksm(int a,int k){
	a%=mod;
	if(a<0)a+=mod;
	if(a==0)return k==0?1:0;   // Fermat reduction below is invalid for a ≡ 0
	k%=mod-1;
	if(k<0)k+=mod-1;           // C++ % truncates toward zero; normalize to [0, mod-1)
	int res=1;
	for(;k;k>>=1,a=a*a%mod)
		if(k&1)res=res*a%mod;
	return res;
}
// Despite its name, this is modular ADDITION: x = (x + y) mod `mod`, in place.
// Assumes x and y already lie in [0, mod), so a single conditional
// subtraction brings the sum back into range.
inline void mul(int&x,int y){
	x=(x+y>=mod)?x+y-mod:x+y;
}
// NOTE(review): unused in the visible code. Together with `st` (declared
// earlier) it may bracket the global variables, e.g. to measure static memory
// via &en-&st — confirm before removing or reordering.
bool en;
// Entry point: reads n and K from "length.in" and writes one number to
// "length.out":
//   sum_{i=1}^{n} t_i * i^(K-1)   (mod `mod`),
// where t_1 = n^n and t_{i+1} = t_i * (n-i) * n^{-1}, i.e.
//   t_i = n^n * prod_{j=1}^{i-1} (n-j)/n.
// NOTE(review): the combinatorial meaning of this sum is not visible here.
// NOTE(review): ksm(n, mod-2) is n's inverse only if n is not a multiple of mod.
// NOTE(review): one ksm call per i, so the loop costs O(n log mod).
signed main(){
	freopen("length.in","r",stdin);
	freopen("length.out","w",stdout);
	n=read();
	K=read();
	fac=ksm(n,n);                          // t_1 = n^n
	const int inv_n=ksm(n,mod-2);          // Fermat inverse of n
	int total=0;
	for(int i=1;i<=n;++i){
		const int term=fac*ksm(i,K-1)%mod;   // t_i * i^(K-1)
		mul(total,term);                     // modular addition (see mul)
		fac=fac*(n-i)%mod*inv_n%mod;         // advance t_i -> t_{i+1}
	}
	printf("%lld\n",total);
	return 0;
}
